Conversation
Code Review
This pull request introduces GPU support and refactors data handling to separate predictors and forcings throughout the training pipeline. Key changes include adding CUDA dependencies, updating configuration objects with device selectors, and modifying data loaders, splitters, and model forward passes to accommodate a new nested tuple input structure. Feedback highlights a mathematical error in the R-squared calculation, potential shape mismatches and incorrect NaN masking in the epoch loop, and several instances of dead code or typos. Additionally, a logic error was identified in a warning check within the data preparation module.
src/losses/loss_fn.jl
Outdated
```diff
 function loss_fn(ŷ, y, y_nan, ::Val{:r2})
-    r = cor(ŷ[y_nan], y[y_nan])
-    return r * r
+    return 1 - sum((y[y_nan] .- ŷ[y_nan]).^2) / sum((y[y_nan] .- mean(ŷ[y_nan])).^2)
```
The R-squared calculation is incorrect. The denominator should use the mean of the observed values (`y`), not the predicted values (`ŷ`). The standard definition of R² gives:

```julia
return 1 - sum((y[y_nan] .- ŷ[y_nan]).^2) / sum((y[y_nan] .- mean(y[y_nan])).^2)
```
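For reference, the coefficient of determination that the suggestion implements, with sums running over the entries selected by `y_nan`:

$$
R^2 = 1 - \frac{\sum_i \left(y_i - \hat{y}_i\right)^2}{\sum_i \left(y_i - \bar{y}\right)^2},
\qquad \bar{y} = \frac{1}{n}\sum_{i=1}^{n} y_i .
$$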
src/training/epoch.jl
Outdated
```julia
is_no_nan = falses(length(first(y))) |> cfg.gdev
for vec in y
    is_no_nan = is_no_nan .|| .!isnan.(vec)
end
```
This logic has two significant issues:

- **Shape mismatch:** `falses(length(first(y)))` creates a 1D array. If the targets are multi-dimensional (e.g., `(time, batch)`), the broadcast OR `.||` will fail. Use `size(first(y))` instead of `length`.
- **Incorrect masking:** Computing a single global `is_no_nan` mask by ORing all targets together is problematic. If target A has a NaN at an index where target B is valid, the global mask will be `true` at that index. Consequently, the loss for target A will be computed using the NaN value, resulting in a `NaN` total loss. Masks should be computed and applied per-target, as sketched below.
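A minimal sketch of the per-target alternative, assuming `y` is a NamedTuple of target arrays (toy names):

```julia
# Toy targets; NaN marks a missing observation.
y = (var1 = [1.0, NaN, 3.0], var2 = [2.0, 3.0, NaN])

# One mask per target, each with the same shape as its target — no global ORing.
masks = map(v -> .!isnan.(v), y)
# (var1 = Bool[1, 0, 1], var2 = Bool[1, 1, 0])
```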
```diff
     @warn "Note that you don't have target names."
 end
-return predictors_forcing, targets
+return predictors, forcings, targets
```
The function now returns predictors and forcings separately, but the warning check at line 115 (visible in context) still references `predictors_forcing`. Since `predictors_forcing` is initialized as an empty array at line 89 and never populated in the new logic, this warning will be triggered on every call. The check should be updated to verify whether both `predictors` and `forcings` are empty, as sketched below.
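A hedged sketch of the corrected check (the stand-in variables and warning text are illustrative, not the actual code at line 115):

```julia
# Stand-ins for the parsed config entries; in the real function these come
# from the new logic that fills predictors and forcings separately.
predictors, forcings = Symbol[], Symbol[]

if isempty(predictors) && isempty(forcings)
    @warn "No predictors or forcings were provided."
end
```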
src/losses/compute_loss.jl
Outdated
```julia
y_t = y[target]# _get_target_y(y, target)
ŷ_t = ŷ[target]#_get_target_ŷ(ŷ, y_t, target)
```
There is a typo in line 109: `ŷ` (y followed by a combining circumflex, U+0079 U+0302) is used instead of the argument `ŷ` (the single precomposed codepoint U+0177) defined at line 105. While Julia normalizes identifiers to NFC, mixing these spellings is confusing and can lead to issues in environments with different normalization rules. Additionally, the commented-out code should be removed.
```julia
y_t = y[target]
ŷ_t = ŷ[target]
```
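A small demonstration of the distinction, using the `Unicode` standard library (the two strings render identically):

```julia
using Unicode

nfc = "\u0177"      # ŷ as one precomposed codepoint (U+0177)
nfd = "y\u0302"     # y followed by a combining circumflex (U+0302)

nfc == nfd                           # false: different codepoint sequences
Unicode.normalize(nfd, :NFC) == nfc  # true: NFC folds them together
```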
src/training/initialization.jl
Outdated
```julia
ps, st = LuxCore.setup(Random.default_rng(), model) |> cfg.gdev
ps = ps |> cfg.gdev
```
Indeed, line 20 is not needed here; the `setup` result on the previous line is already piped through `cfg.gdev`.
src/data/prepare_data.jl
Outdated
```diff
-return (data(predictors_forcing), data(targets))
+dev = cfg.gdev
+targets_nt = NamedTuple([target => dev(Array(data(target))) for target in targets])
+forcings_nt = NamedTuple([forcing => dev(Array(data(forcing))) for forcing in forcings])
```
We should do the `dev`/`Array` conversion at the batch-loader level; up to this point the data could still be lazy.
```diff
-ŷ_train, _ = model(x_train, ps, LuxCore.testmode(st))
-ŷ_val, _ = model(x_val, ps, LuxCore.testmode(st))
+ŷ_train, _ = model((cfg.cdev(x_train), cfg.cdev(forcings_train)), cfg.cdev(ps), LuxCore.testmode(st))
+ŷ_val, _ = model((cfg.cdev(x_val), cfg.cdev(forcings_val)), cfg.cdev(ps), LuxCore.testmode(st))
```
I think we can still evaluate this on the GPU side and just pipe the result into the `cfg.cdev` function.
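A runnable toy version of that pattern (a plain `Dense` layer stands in for the hybrid model, and the forcings tuple is omitted for brevity):

```julia
using Lux, Random

gdev, cdev = gpu_device(), cpu_device()   # cpu_device() is returned if no GPU is present

model = Dense(4 => 1)
ps, st = Lux.setup(Random.default_rng(), model) |> gdev
x_train = gdev(rand(Float32, 4, 16))

# Evaluate on the device, then move only the predictions back to the CPU.
ŷ_train, _ = model(x_train, ps, Lux.testmode(st))
ŷ_train = cdev(ŷ_train)
```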
src/training/epoch.jl
Outdated
```julia
for (x, y) in cfg.gdev(loader)
    is_no_nan = falses(length(first(y))) |> cfg.gdev
```
`cfg.gdev(loader)` is already moving the data onto the GPU device, so the second line should already be an operation on the GPU side; hence the `|> cfg.gdev` should not be needed, in principle.
Comments are being addressed in #257.
…at needs to be done in the outer loop, refactoring genericHybrid is needed for that
Main things I changed here are:
I think the main thing that needs work here and/or in #257 is device switching, namely that the device switch should happen at the batch level. CUDA needs contiguous arrays anyway, so views are a no-go. That means we have to allocate a new array at every batch anyway, so we may as well do the device switching there. I worked here before I saw all the comments, so I'll now switch to #257.
We can also continue here, whichever branch is easier to merge. Both are still open. I just didn't want to have merge conflicts in case you did local work 😌, as you have done 👍.
```julia
@testset "_compute_loss" begin
    # Test data setup
    ŷ = Dict(:var1 => [1.0, 2.0, 3.0], :var2 => [2.0, 3.0, 4.0])
    y(target) = target == :var1 ? [1.1, 1.9, 3.2] : [1.8, 3.1, 3.9]
```
Doing `y(target)` was intended to mirror AxisKeys syntax, although admittedly it would be better to do an independent test for that, and similarly for DD.
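For context, a toy sketch of the AxisKeys lookup that the callable mimics:

```julia
using AxisKeys

# A KeyedArray supports keyword lookup along a named dimension.
ka = KeyedArray([1.1 1.8; 1.9 3.1; 3.2 3.9]; time = 1:3, variable = [:var1, :var2])
ka(variable = :var1)   # keyed lookup, analogous to the test's y(:var1)
```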
For now I've written the interface for loss functions to accept named tuples. So I may have to add a dispatch for callables.
These should work again.
Regarding the data:
```diff
@@ -2,6 +2,10 @@ function run_epoch!(loader, model, ps, st, train_state, cfg::TrainConfig)
     loss_fn = build_loss_fn(model, cfg)

     for (x, y) in loader
```
I think it is better to do `cfg.gdev(loader)`; it is just a wrapper at this point, I think. Then, when looping over it, the `(x, y)` batches will be sent to the GPU. If it works now, we can come back to this later.
It works to do `cfg.gdev(loader)`, except for dimensional data: that is still lazy and is not a simple array like the rest.
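For the simple-array case, a runnable sketch of the wrapper pattern (toy data; per MLDataDevices, wrapping is cheap and each batch is transferred as it is iterated):

```julia
using MLUtils, MLDataDevices

X = rand(Float32, 4, 32)   # toy features
Y = rand(Float32, 1, 32)   # toy targets
loader = DataLoader((X, Y); batchsize = 8)

gdev = gpu_device()        # falls back to the CPU device if no GPU is available
for (x, y) in gdev(loader) # each (x, y) batch is moved to the device here
    @assert size(x) == (4, 8)
end
```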
Most things have been addressed. Docs still have problems.
I suppose the docs issues are related to the new prepare-data step.
/gemini review |
Code Review
This pull request significantly refactors the data handling and training pipeline to better separate predictors and forcings, and to improve GPU compatibility. Key changes include updating data structures (`prepare_data`, `split_data`) to explicitly distinguish between predictors and forcings, integrating Lux's `gpu_device` and `cpu_device` for managing device transfers during training and saving, and adapting various training functions (`run_epoch!`, `evaluate_epoch`, `save_final!`) to these new data formats and device considerations. The `r2` loss calculation has been redefined, and new utility functions like `collect_dim_data` and `valid_mask` are introduced.

Review comments highlight critical areas for improvement:

- addressing type instability in `TrainConfig`'s device fields,
- ensuring consistent CPU transfer for model states during evaluation,
- improving the robustness of data access in `to_named_tuple`,
- optimizing data collection in `collect_dim_data` for efficiency,
- making the R² loss calculation GPU-friendly,
- and enhancing semantic clarity in loss computation.
| "Select a gpu_device or default to cpu if none available" | ||
| gdev = gpu_device() | ||
|
|
||
| "Set the `cpu_device`, useful for sending back to the cpu model parameters" | ||
| cdev = cpu_device() |
The fields `gdev` and `cdev` are untyped in the `TrainConfig` struct. In Julia, untyped fields lead to type instability, which can significantly degrade performance because the compiler cannot specialize functions using these fields. Since these are used frequently for device transfers during training, it is highly recommended to provide type annotations, such as `Lux.AbstractDevice`, or to use type parameters.
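A hedged sketch of the type-parameter route, reduced to just the two device fields (the real struct has more):

```julia
using MLDataDevices

# Parametric fields give the compiler concrete device types to specialize on.
struct TrainConfig{G <: MLDataDevices.AbstractDevice, C <: MLDataDevices.AbstractDevice}
    gdev::G
    cdev::C
end

cfg = TrainConfig(gpu_device(), cpu_device())  # concrete types, so field access is stable
```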
```diff
 function to_named_tuple(ka, target_names)
-    arrays = [Array(ka(variable = k)) for k in target_names]
+    arrays = [Array(ka[k]) for k in target_names]
```
The change from `ka(variable = k)` to `ka[k]` will break if `ka` is a `KeyedArray`, as `KeyedArray` indexing typically requires dimension names or positional indices. While this might work if `ka` is now a `NamedTuple` due to changes in `prepare_data`, it makes the function less robust if it's still intended to handle `KeyedArray` inputs.
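If both container types must keep working, a hedged sketch using a small dispatch helper (`_lookup` is hypothetical, not in the codebase):

```julia
using AxisKeys

_lookup(nt::NamedTuple, k) = nt[k]   # NamedTuple field access
_lookup(ka, k) = ka(variable = k)    # KeyedArray keyword lookup (AxisKeys)

function to_named_tuple(ka, target_names)
    return NamedTuple(k => Array(_lookup(ka, k)) for k in target_names)
end
```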
```julia
x_col = Array(x[1])
forcing_nt = NamedTuple([k => Array(v) for (k, v) in pairs(x[2])])
targets_nt = NamedTuple([k => Array(v) for (k, v) in pairs(y[1])])
masks_nt = NamedTuple([k => Array(v) for (k, v) in pairs(y[2])])
```
The construction of `forcing_nt`, `targets_nt`, and `masks_nt` using a loop and the `NamedTuple` constructor is inefficient. Since the inputs are already `NamedTuple`s (as per the changes in `prepare_data`), you can use `map`, which is more idiomatic and performant in Julia. Additionally, calling `Array(v)` on every batch is redundant if the data is already on the CPU.

```julia
x_col = x[1]
forcing_nt = map(identity, x[2])
targets_nt = map(identity, y[1])
masks_nt = map(identity, y[2])
```
The reason behind this is that if the inputs are `DimArray`s they need to be collected explicitly, hence the `Array` calls.
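A possible middle ground that keeps the `map` idiom while still materializing lazy values (sketch with a toy stand-in for the batch NamedTuple):

```julia
# Toy stand-in for a batch's forcings NamedTuple.
x2 = (tair = rand(Float32, 3), prec = rand(Float32, 3))

# map preserves the NamedTuple keys; Array collects DimArrays into plain
# arrays, at the cost of one copy when a value is already an Array.
forcing_nt = map(Array, x2)
```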
```diff
 function loss_fn(ŷ, y, y_nan, ::Val{:r2})
-    r = cor(ŷ[y_nan], y[y_nan])
-    return r * r
+    return 1 - sum((y[y_nan] .- ŷ[y_nan]) .^ 2) / sum((y[y_nan] .- mean(y[y_nan])) .^ 2)
```
The implementation of R² using logical indexing (e.g., `y[y_nan]`) is inefficient on GPUs because it often triggers scalar indexing, which is extremely slow. A more GPU-friendly approach is to use the mask to zero out invalid entries and then perform vectorized reductions. Also, note that this change shifts the definition of R² from the squared Pearson correlation to the coefficient of determination, which might be a breaking change for users.
```julia
function loss_fn(ŷ, y, y_nan, ::Val{:r2})
    # ifelse keeps NaNs from propagating: multiplying a NaN entry by a
    # false/0 mask would still yield NaN, so masked entries are replaced
    # by zero instead.
    resid  = ifelse.(y_nan, y .- ŷ, zero(eltype(ŷ)))
    ss_res = sum(abs2, resid)
    y_mean = sum(ifelse.(y_nan, y, zero(eltype(y)))) / sum(y_nan)
    ss_tot = sum(abs2, ifelse.(y_nan, y .- y_mean, zero(eltype(y))))
    return 1 - ss_res / ss_tot
end
```

```julia
y_nan_t = _get_target_y(y_nan, target)
_apply_loss(ŷ_t, y_t, y_nan_t, loss_spec)
# _apply_loss(ŷ_t, y_t, _get_target_nan(y_nan, target), loss_spec)
```
The use of `_get_target_y` for a variable named `y_nan` is semantically confusing. It is better to use `_get_target_nan`, which was specifically defined for this purpose. Also, please remove the commented-out code to keep the codebase clean.
```julia
y_nan_t = _get_target_nan(y_nan, target)
_apply_loss(ŷ_t, y_t, y_nan_t, loss_spec)
```